{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp callback.data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Callbacks\n",
    "\n",
    "> Callbacks which work with a learner's data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "from fastai.basics import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *\n",
    "from fastai.test_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class CollectDataCallback(Callback):\n",
    "    \"Collect all batches, along with `pred` and `loss`, into `self.data`. Mainly for testing\"\n",
    "    def before_fit(self): self.data = L()\n",
    "    def after_batch(self): \n",
    "        self.data.append(self.learn.to_detach((self.xb,self.yb,self.pred,self.loss)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@delegates()\n",
    "class WeightedDL(TfmdDL):\n",
    "    \"Weighted dataloader where `wgts` is used for the training set only\"\n",
    "    def __init__(self, dataset=None, bs=None, wgts=None, **kwargs):\n",
    "        wgts = array([1.]*len(dataset) if wgts is None else wgts)\n",
    "        self.wgts = wgts/wgts.sum()\n",
    "        super().__init__(dataset=dataset, bs=bs, **kwargs)\n",
    "\n",
    "    def get_idxs(self):\n",
    "        if self.n==0: return []\n",
    "        if not self.shuffle: return super().get_idxs()\n",
    "        return list(np.random.choice(self.n, self.n, p=self.wgts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "@delegates(Datasets.dataloaders)\n",
    "def weighted_dataloaders(self:Datasets, wgts, bs=64, **kwargs):\n",
    "    \"Create a weighted dataloader `WeightedDL` with `wgts` for the training set\"\n",
    "    xtra_kwargs = [{}] * (self.n_subsets-1)\n",
    "    return self.dataloaders(bs=bs, dl_type=WeightedDL, dl_kwargs=({'wgts':wgts}, *xtra_kwargs), **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lbls = np.random.randint(0, 2, size=(10)) # Dataset of size 10 (train=8, valid=2)\n",
    "is_valid = lambda i: i >= 8\n",
    "dblock = DataBlock(blocks=[CategoryBlock], \n",
    "    getters=[lambda i: lbls[i]], splitter=FuncSplitter(is_valid))\n",
    "dset = dblock.datasets(list(range(10)))\n",
    "item_tfms = [ToTensor()] \n",
    "wgts = range(8) # len(wgts) == 8\n",
    "dls = dset.weighted_dataloaders(bs=1, wgts=wgts, after_item=item_tfms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n"
     ]
    }
   ],
   "source": [
    "dls.show_batch() # if len(wgts) != 8, this will fail\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 160\n",
    "dsets = Datasets(torch.arange(n).float())\n",
    "dls = dsets.weighted_dataloaders(wgts=range(n), bs=16)\n",
    "learn = synth_learner(data=dls, cbs=CollectDataCallback)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, nan, None, '00:00']\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOf0lEQVR4nO3db4xldX3H8fen/FORhqU70C1gBw2akiYuZkqhtAZBLaARfWAiqWab0qwPpNHW/lkkafUZWMU+aWzWQt1UpCEKQsBW6dbWmBjoQAGXLhTQVRdXdqhpwTZpBb59cM/acZzZe3fuvXPPD9+v5Obe8ztn5nyY2fvh3N89506qCklSe35q1gEkSetjgUtSoyxwSWqUBS5JjbLAJalRR2/kzjZv3lzz8/MbuUtJat699977VFXNrRzf0AKfn59ncXFxI3cpSc1L8s3Vxp1CkaRGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRm3olZiSNEvzO+6c2b73XfOmiX9Pj8AlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSoyxwSWqUBS5JjRpa4ElelOSeJA8keSjJh7rxk5LcleTR7n7T9ONKkg4Z5Qj8f4ALq+rVwFbg4iTnAjuA3VV1JrC7W5YkbZChBV4D3+8Wj+luBVwG7OrGdwFvnUZASdLqRpoDT3JUkvuBg8BdVXU3cEpVHQDo7k+eWkpJ0o8Z6cOsquo5YGuSE4Fbk/ziqDtIsh3YDvCyl71sPRklTcEL7YOdfhId0VkoVfUfwD8CFwNPJtkC0N0fXONrdlbVQlUtzM3NjZdWkvRDo5yFMtcdeZPkxcDrgYeB24Ft3WbbgNumlFGStIpRplC2ALuSHMWg8G+uqjuSfBW4OckVwLeAt08xpyRphaEFXlUPAmevMv7vwEXTCCVJGs4rMSWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSo0b6k2qSNEmz/HNuLyQegUtSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqOGFniS05N8KcneJA8leW83/sEkTyS5v7tdOv24kqRDRrmU/lng/VV1X5ITgHuT3NWt+1hVfWR68SRJaxla4FV1ADjQPX4myV7g1GkHkyQd3hHNgSeZB84G7u6GrkzyYJIbkmxa42u2J1lMsri0tDReWknSD41c4EleCnwWeF9VPQ18HHgFsJXBEfpHV/u6qtpZVQtVtTA3Nzd+YkkSMGKBJzmGQXnfWFW3AFTVk1X1XFU9D3wCOGd6MSVJK41yFkqA64G9VXXdsvEtyzZ7G7Bn8vEkSWsZ5SyU84F3AV9Lcn839gHg8iRbgQL2Ae+eQj5J0hpGOQvlK0BWWfX5yceRJI3KKzElqVEWuCQ1ygKXpEZZ4JLUKAtckhplgUtSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJatTQAk9yepIvJdmb5KEk7+3GT0pyV5JHu/tN048rSTpklCPwZ4H3V9UvAOcC70lyFrAD2F1VZwK7u2VJ0gYZWuBVdaCq7usePwPsBU4FLgN2dZvtAt46pYySpFUc0Rx4knngbOBu4JSqOgCDkgdOXuNrtidZTLK4tLQ0ZlxJ0iEjF3iSlwKfBd5XVU+P+nVVtbOqFqpqYW5ubj0ZJUmrGKnAkxzDoLxvrKpbuuEnk2zp1m8BDk4noiRpNaOchRLgemBvVV23bNXtwLbu8TbgtsnHkySt5egRtjkfeBfwtST3d2MfAK4Bbk5yBfAt4O1TSShJWtXQAq+qrwBZY/VFk40jSRqVV2JKUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktSoUS7kkV7w5nfcObN977vmTTPbt9rmEbgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGuWHWUkzNssP0lLbPAKXpEZZ4JLUKAtckho1tMCT3JDkYJI9y8Y+mOSJJPd3t0unG1OStNIoR+CfBC5eZfxjVbW1u31+srEkScMMLfCq+jLwvQ3IIkk6AuPMgV+Z5MFuimXTWhsl2Z5kMcni0tLSGLuTJC233gL/OPAKYCtwAPjoWhtW1c6qWqiqhbm5uXXuTpK00roKvKqerKrnqup54BPAOZONJUkaZl0FnmTLssW3AXvW2laSNB1DL6VPchNwAbA5yX7gT4ALkmwFCtgHvHt6ESVJqxla4FV1+SrD108hiyTpCPhhVuoVP9hJGp2X0ktSoyxwSWqUBS5JjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEYNLfAkNyQ5mGTPsrGTktyV5NHuftN0Y0qSVhrlCPyTwMUrxnYAu6vqTGB3tyxJ2kBDC7yqvgx8b8XwZcCu7vEu4K2TjSVJGma9c+CnVNUBgO7+5LU2TLI9yWKSxaWlpXXuTpK00tTfxKyqnVW1UFULc3Nz096dJP3EWG+BP5lkC0B3f3BykSRJo1hvgd8ObOsebwNum0wcSdKoRjmN8Cbgq8CrkuxPcgVwDfCGJI8Cb+iWJUkb6OhhG1TV5WusumjCWSRJR8ArMSWpUUOPwPWTZ37HnbOOIGkEHoFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVF+mNUIZvXhTvuuedNM9iupDR6BS1KjLHBJapQFLkmNssAlqVEWuCQ1yrNQesw/bSbpcDwCl6RGWeCS1KixplCS7AOeAZ4Dnq2qhUmEkiQNN4k58NdV1VMT+D6SpCPgFIokNWrcAi/gi0nuTbJ9tQ2SbE+ymGRxaWlpzN1Jkg4Zt8DPr6rXAJcA70ny2pUbVNXOqlqoqoW5ubkxdydJOmSsAq+q73T3B4FbgXMmEUqSNNy6CzzJ8UlOOPQYeCOwZ1LBJEmHN85ZKKcAtyY59H0+XVV/N5FUkqSh1l3gVfV14NUTzCJJOgKeRihJjbLAJalRFrgkNcoCl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY2ywCWpURa4JDXKApekRlngktQoC1ySGmWBS1KjLHBJapQFLkmNssAlqVEWuCQ1ygKXpEaN81fpN9T8jjtnHUGSesUjcElqlAUuSY2ywCWpUWMVeJKLkzyS5LEkOyYVSpI03LoLPMlRwJ8DlwBnAZcnOWtSwSRJhzfOEfg5wGNV9fWq+l/gb4DLJhNLkjTMOKcRngp8e9nyfuCXV26UZDuwvVv8fpJH1vh+m4GnxsgzbX3O1+ds0O98fc4G/c7X52zQs3y59kcWjzTbz682OE6BZ5Wx+rGBqp3AzqHfLFmsqoUx8kxVn/P1ORv0O1+fs0G/8/U5G/Q736SyjTOFsh84fdnyacB3xosjSRrVOAX+z8CZSc5IcizwDuD2ycSSJA2z7imUqno2yZXAF4CjgBuq6qExsgydZpmxPufrczbod74+Z4N+5+tzNuh3volkS9WPTVtLkhrglZiS1CgLXJIa1YsC79Ml+UlOT/KlJHuTPJTkvd34SUnuSvJod79phhmPSvIvSe7oYbYTk3wmycPdz/C8nuX73e73uifJTUleNKt8SW5IcjDJnmVja2ZJclX3HHkkya/PKN+fdr/bB5PcmuTEPuVbtu73k1SSzbPIt1a2JL/T7f+hJB8eO1tVzfTG4A3Qx4GXA8cCDwBnzTDPFuA13eMTgH9j8FEBHwZ2dOM7gGtnmPH3gE8Dd3TLfcq2C/jt7vGxwIl9ycfg4rNvAC/ulm8GfnNW+YDXAq8B9iwbWzVL92/wAeA44IzuOXPUDPK9ETi6e3xt3/J146czOLnim8DmWeRb42f3OuDvgeO65ZPHzbZhT57D/IeeB3xh2fJVwFWzzrUsz23AG4BHgC3d2BbgkRnlOQ3YDVy4rMD7ku2nu4LMivG+5Dt09fBJDM7AuqMrpJnlA+ZXPMlXzbLyedEV1HkbnW/FurcBN/YtH/AZ4NXAvmUFvuH5Vvnd3gy8fpXt1p2tD1Moq12Sf+qMsvyIJPPA2cDdwClVdQCguz95RrH+DPhD4PllY33J9nJgCfirbornL5Mc35d8VfUE8BHgW8AB4D+r6ot9yddZK0sfnye/Bfxt97gX+ZK8BXiiqh5YsaoP+V4J/FqSu5P8U5JfGjdbHwp8pEvyN1qSlwKfBd5XVU/POg9AkjcDB6vq3llnWcPRDF42fryqzgb+i8E0QC9088mXMXiZ+nPA8UneOdtUI+vV8yTJ1cCzwI2HhlbZbEPzJXkJcDXwx6utXmVso39+RwObgHOBPwBuThLGyNaHAu/dJflJjmFQ3jdW1S3d8JNJtnTrtwAHZxDtfOAtSfYx+PTHC5N8qifZYPC73F9Vd3fLn2FQ6H3J93rgG1W1VFU/AG4BfqVH+ThMlt48T5JsA94M/EZ1r/npR75XMPif8wPdc+Q04L4kP9uTfPuBW2rgHgavojePk60PBd6rS/K7/yNeD+ytquuWrbod2NY93sZgbnxDVdVVVXVaVc0z+Dn9Q1W9sw/ZunzfBb6d5FXd0EXAv9KTfAymTs5N8pLu93wRsLdH+ThMltuBdyQ5LskZwJnAPRsdLsnFwB8Bb6mq/162aub5quprVXVyVc13z5H9DE5I+G4f8gGfY/DeFUleyeBN/qfGyjbtNxlGnOy/lMHZHo8DV884y68yePnyIHB/d7sU+BkGbx4+2t2fNOOcF/D/b2L2JhuwFVjsfn6fY/CSsU/5PgQ8DOwB/prBO/8zyQfcxGAu/gcMyuaKw2VhMD3wOIM3Oi+ZUb7HGMzXHnpu/EWf8q1Yv4/uTcyNzrfGz+5Y4FPdv737gAvHzeal9JLUqD5MoUiS1sECl6RGWeCS1CgLXJIaZYFLUqMscElqlAUuSY36P8QgsP+r6YTLAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit(1)\n",
    "t = concat(*learn.collect_data.data.itemgot(0,0))\n",
    "plt.hist(t.numpy());"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "@delegates(Datasets.weighted_dataloaders)\n",
    "def weighted_dataloaders(self:DataBlock, source, wgts, bs=64, verbose:bool=False, **kwargs):\n",
    "    \"Create a weighted dataloader `WeightedDL` with `wgts` for the dataset\"\n",
    "    dss = self.datasets(source, verbose=verbose)\n",
    "    if not hasattr(wgts, '__array__'): wgts = np.array(wgts)\n",
    "    trn_wgts = wgts[dss.splits[0]]\n",
    "    return dss.weighted_dataloaders(trn_wgts, bs=bs, after_batch=self.batch_tfms, after_item=self.item_tfms, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "dls = dblock.weighted_dataloaders(list(range(10)), wgts, bs=1)\n",
    "dls.show_batch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@delegates()\n",
    "class PartialDL(TfmdDL):\n",
    "    \"Select randomly partial quantity of data at each epoch\"\n",
    "    def __init__(self, dataset=None, bs=None, partial_n=None, **kwargs):\n",
    "        super().__init__(dataset=dataset, bs=bs, **kwargs)\n",
    "        self.partial_n = min(partial_n, self.n) if partial_n else None\n",
    "\n",
    "    def get_idxs(self):\n",
    "        if self.partial_n is None: return super().get_idxs()\n",
    "        return list(np.random.choice(self.n, self.partial_n, replace=False))\n",
    "\n",
    "    def __len__(self):\n",
    "        if self.partial_n is None: return super().__len__()\n",
    "        return self.partial_n//self.bs + (0 if self.drop_last or self.partial_n%self.bs==0 else 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "@delegates(Datasets.dataloaders)\n",
    "def partial_dataloaders(self:FilteredBase, partial_n, bs=64, **kwargs):\n",
    "    \"Create a partial dataloader `PartialDL` for the training set\"\n",
    "    xtra_kwargs = [{}] * (self.n_subsets-1)\n",
    "    return self.dataloaders(bs=bs, dl_type=PartialDL, dl_kwargs=({'partial_n':partial_n}, *xtra_kwargs), **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = dsets.partial_dataloaders(partial_n=32, bs=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(dls[0])==2\n",
    "for batch in dls[0]:\n",
    "    assert len(batch[0])==16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
